Statistical learning: classification and cross-validation

MACS 30500 University of Chicago

Should I Have a Cookie?

Interpreting a decision tree

A more complex tree

## 
## Model formula:
## Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare + Embarked
## 
## Fitted party:
## [1] root
## |   [2] Sex in female
## |   |   [3] Pclass <= 2: Survived (n = 170, err = 5.3%)
## |   |   [4] Pclass > 2
## |   |   |   [5] Fare <= 23.25: Survived (n = 117, err = 41.0%)
## |   |   |   [6] Fare > 23.25: Died (n = 27, err = 11.1%)
## |   [7] Sex in male
## |   |   [8] Pclass <= 1
## |   |   |   [9] Age <= 52: Died (n = 98, err = 41.8%)
## |   |   |   [10] Age > 52: Died (n = 24, err = 16.7%)
## |   |   [11] Pclass > 1
## |   |   |   [12] Age <= 9
## |   |   |   |   [13] Pclass <= 2: Survived (n = 12, err = 25.0%)
## |   |   |   |   [14] Pclass > 2: Died (n = 29, err = 31.0%)
## |   |   |   [15] Age > 9: Died (n = 414, err = 11.1%)
## 
## Number of inner nodes:    7
## Number of terminal nodes: 8
## [1] 0.1829405

An even more complex tree

## 
## Model formula:
## Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare + Embarked
## 
## Fitted party:
## [1] root
## |   [2] Sex in female
## |   |   [3] Pclass <= 2: Survived (n = 170, err = 5.3%)
## |   |   [4] Pclass > 2
## |   |   |   [5] Fare <= 23.25
## |   |   |   |   [6] Age <= 16
## |   |   |   |   |   [7] Embarked in C, S: Survived (n = 22, err = 31.8%)
## |   |   |   |   |   [8] Embarked in Q: Survived (n = 7, err = 28.6%)
## |   |   |   |   [9] Age > 16: Survived (n = 88, err = 44.3%)
## |   |   |   [10] Fare > 23.25: Died (n = 27, err = 11.1%)
## |   [11] Sex in male
## |   |   [12] Pclass <= 1
## |   |   |   [13] Age <= 52: Died (n = 94, err = 43.6%)
## |   |   |   [14] Age > 52: Died (n = 28, err = 14.3%)
## |   |   [15] Pclass > 1
## |   |   |   [16] Age <= 9
## |   |   |   |   [17] Pclass <= 2: Survived (n = 10, err = 10.0%)
## |   |   |   |   [18] Pclass > 2
## |   |   |   |   |   [19] SibSp <= 1: Died (n = 19, err = 42.1%)
## |   |   |   |   |   [20] SibSp > 1: Died (n = 15, err = 6.7%)
## |   |   |   [21] Age > 9: Died (n = 411, err = 11.2%)
## 
## Number of inner nodes:    10
## Number of terminal nodes: 11
## [1] 0.1806958

Benefits/drawbacks to decision trees

  • Easy to explain
  • Easy to interpret/visualize
  • Good for qualitative predictors
  • Lower accuracy rates
  • Non-robust

Random forests

Sampling with replacement

(numbers <- seq(from = 1, to = 10))
##  [1]  1  2  3  4  5  6  7  8  9 10
# sample without replacement
# (purrr::rerun() is superseded; map() is the supported equivalent)
map(1:5, ~ sample(numbers, replace = FALSE))
## [[1]]
##  [1]  6  4  1 10  9  7  5  2  3  8
## 
## [[2]]
##  [1]  9  8  7  1  4  5  3 10  6  2
## 
## [[3]]
##  [1]  2  4  7  1 10  8  3  5  6  9
## 
## [[4]]
##  [1]  4  6  3  1  7 10  5  8  2  9
## 
## [[5]]
##  [1]  8  6  7  5  9  3 10  1  4  2
# sample with replacement
# (purrr::rerun() is superseded; map() is the supported equivalent)
map(1:5, ~ sample(numbers, replace = TRUE))
## [[1]]
##  [1]  5  4  2  3 10  1  5  3  5  8
## 
## [[2]]
##  [1]  8 10  2  9  9 10  9  1  5  3
## 
## [[3]]
##  [1]  6  3  1 10  4  3  7  8  1  7
## 
## [[4]]
##  [1]  2  3  7  2  4 10  8  5  4  8
## 
## [[5]]
##  [1]  3  3  4  9  9  1 10  2  1  7

Random forests

  • Bootstrapping
  • Reduces variance
  • Bagging
  • Random forest
    • Reliability

Estimating statistical models using caret

  • Not part of tidyverse (yet)
  • Aggregator of hundreds of statistical learning algorithms
  • Provides a single unified interface to disparate range of functions
    • Similar to scikit-learn for Python

train()

library(caret)

# drop observations with a missing outcome or predictor before fitting
titanic_clean <- titanic %>%
  filter(!is.na(Survived), !is.na(Age))

# logistic regression fit through caret's unified train() interface;
# trControl = "none" disables resampling, so this is a single fit on
# the full cleaned data (equivalent to calling glm() directly)
caret_glm <- train(Survived ~ Age, data = titanic_clean,
                   method = "glm",
                   family = binomial,
                   trControl = trainControl(method = "none"))
summary(caret_glm)
## 
## Call:
## NULL
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -1.1488  -1.0361  -0.9544   1.3159   1.5908  
## 
## Coefficients:
##             Estimate Std. Error z value Pr(>|z|)  
## (Intercept) -0.05672    0.17358  -0.327   0.7438  
## Age         -0.01096    0.00533  -2.057   0.0397 *
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 964.52  on 713  degrees of freedom
## Residual deviance: 960.23  on 712  degrees of freedom
## AIC: 964.23
## 
## Number of Fisher Scoring iterations: 4

Estimating a random forest

# random forest of 200 trees on Age + Sex; method = "oob" uses the
# out-of-bag observations to estimate accuracy, so no separate
# validation set or resampling loop is required
age_sex_rf <- train(Survived ~ Age + Sex, data = titanic_rf_data,
                   method = "rf",
                   ntree = 200,
                   trControl = trainControl(method = "oob"))
## note: only 1 unique complexity parameters in default grid. Truncating the grid to 1 .
age_sex_rf
## Random Forest 
## 
## 714 samples
##   2 predictor
##   2 classes: 'Died', 'Survived' 
## 
## No pre-processing
## Resampling results:
## 
##   Accuracy   Kappa    
##   0.7507003  0.4734426
## 
## Tuning parameter 'mtry' was held constant at a value of 2

Structure of train() object

## List of 24
##  $ method      : chr "rf"
##  $ modelInfo   :List of 15
##  $ modelType   : chr "Classification"
##  $ results     :'data.frame':    1 obs. of  3 variables:
##  $ pred        : NULL
##  $ bestTune    :'data.frame':    1 obs. of  1 variable:
##  $ call        : language train.formula(form = Survived ~ Age + Sex, data = titanic_rf_data,      method = "rf", ntree = 200, trControl = t| __truncated__
##  $ dots        :List of 1
##  $ metric      : chr "Accuracy"
##  $ control     :List of 26
##  $ finalModel  :List of 23
##   ..- attr(*, "class")= chr "randomForest"
##  $ preProcess  : NULL
##  $ trainingData:Classes 'tbl_df', 'tbl' and 'data.frame':    714 obs. of  3 variables:
##  $ resample    : NULL
##  $ resampledCM : NULL
##  $ perfNames   : chr [1:2] "Accuracy" "Kappa"
##  $ maximize    : logi TRUE
##  $ yLimits     : NULL
##  $ times       :List of 3
##  $ levels      : chr [1:2] "Died" "Survived"
##   ..- attr(*, "ordered")= logi FALSE
##  $ terms       :Classes 'terms', 'formula'  language Survived ~ Age + Sex
##   .. ..- attr(*, "variables")= language list(Survived, Age, Sex)
##   .. ..- attr(*, "factors")= int [1:3, 1:2] 0 1 0 0 0 1
##   .. .. ..- attr(*, "dimnames")=List of 2
##   .. ..- attr(*, "term.labels")= chr [1:2] "Age" "Sex"
##   .. ..- attr(*, "order")= int [1:2] 1 1
##   .. ..- attr(*, "intercept")= int 1
##   .. ..- attr(*, "response")= int 1
##   .. ..- attr(*, ".Environment")=<environment: R_GlobalEnv> 
##   .. ..- attr(*, "predvars")= language list(Survived, Age, Sex)
##   .. ..- attr(*, "dataClasses")= Named chr [1:3] "factor" "numeric" "factor"
##   .. .. ..- attr(*, "names")= chr [1:3] "Survived" "Age" "Sex"
##  $ coefnames   : chr [1:2] "Age" "Sexmale"
##  $ contrasts   :List of 1
##  $ xlevels     :List of 1
##  - attr(*, "class")= chr [1:2] "train" "train.formula"

Model statistics

## 
## Call:
##  randomForest(x = x, y = y, ntree = 200, mtry = param$mtry) 
##                Type of random forest: classification
##                      Number of trees: 200
## No. of variables tried at each split: 2
## 
##         OOB estimate of  error rate: 24.23%
## Confusion matrix:
##          Died Survived class.error
## Died      357       67   0.1580189
## Survived  106      184   0.3655172

Results of a single tree

##     left daughter right daughter split var split point status prediction
## 1               2              3   Sexmale       0.500      1       <NA>
## 2               4              5       Age       3.500      1       <NA>
## 3               6              7       Age       5.500      1       <NA>
## 4               8              9       Age       1.375      1       <NA>
## 5              10             11       Age      32.250      1       <NA>
## 6              12             13       Age       0.960      1       <NA>
## 7              14             15       Age      77.000      1       <NA>
## 8               0              0      <NA>       0.000     -1   Survived
## 9              16             17       Age       2.500      1       <NA>
## 10             18             19       Age      24.500      1       <NA>
## 11             20             21       Age      38.500      1       <NA>
## 12              0              0      <NA>       0.000     -1   Survived
## 13             22             23       Age       2.000      1       <NA>
## 14             24             25       Age      50.500      1       <NA>
## 15              0              0      <NA>       0.000     -1   Survived
## 16              0              0      <NA>       0.000     -1       Died
## 17              0              0      <NA>       0.000     -1       Died
## 18             26             27       Age       5.500      1       <NA>
## 19             28             29       Age      25.500      1       <NA>
## 20              0              0      <NA>       0.000     -1   Survived
## 21             30             31       Age      50.500      1       <NA>
## 22              0              0      <NA>       0.000     -1   Survived
## 23             32             33       Age       3.500      1       <NA>
## 24             34             35       Age      47.500      1       <NA>
## 25             36             37       Age      56.500      1       <NA>
## 26              0              0      <NA>       0.000     -1   Survived
## 27             38             39       Age      12.000      1       <NA>
## 28              0              0      <NA>       0.000     -1       Died
## 29             40             41       Age      30.250      1       <NA>
## 30             42             43       Age      49.500      1       <NA>
## 31              0              0      <NA>       0.000     -1   Survived
## 32              0              0      <NA>       0.000     -1   Survived
## 33              0              0      <NA>       0.000     -1   Survived
## 34             44             45       Age      45.250      1       <NA>
## 35             46             47       Age      48.500      1       <NA>
## 36             48             49       Age      55.500      1       <NA>
## 37              0              0      <NA>       0.000     -1       Died
## 38             50             51       Age       8.500      1       <NA>
## 39             52             53       Age      19.500      1       <NA>
## 40             54             55       Age      28.500      1       <NA>
## 41             56             57       Age      31.500      1       <NA>
## 42             58             59       Age      41.500      1       <NA>
## 43              0              0      <NA>       0.000     -1       Died
## 44             60             61       Age      43.500      1       <NA>
## 45              0              0      <NA>       0.000     -1       Died
## 46              0              0      <NA>       0.000     -1   Survived
## 47             62             63       Age      49.500      1       <NA>
## 48              0              0      <NA>       0.000     -1       Died
## 49              0              0      <NA>       0.000     -1       Died
## 50             64             65       Age       6.500      1       <NA>
## 51              0              0      <NA>       0.000     -1       Died
## 52             66             67       Age      18.500      1       <NA>
## 53             68             69       Age      20.500      1       <NA>
## 54             70             71       Age      27.500      1       <NA>
## 55             72             73       Age      29.500      1       <NA>
## 56              0              0      <NA>       0.000     -1       Died
## 57              0              0      <NA>       0.000     -1   Survived
## 58             74             75       Age      39.500      1       <NA>
## 59             76             77       Age      46.000      1       <NA>
## 60             78             79       Age      32.250      1       <NA>
## 61             80             81       Age      44.500      1       <NA>
## 62              0              0      <NA>       0.000     -1       Died
## 63              0              0      <NA>       0.000     -1       Died
## 64              0              0      <NA>       0.000     -1       Died
## 65             82             83       Age       7.500      1       <NA>
## 66             84             85       Age      13.500      1       <NA>
## 67              0              0      <NA>       0.000     -1   Survived
## 68              0              0      <NA>       0.000     -1       Died
## 69             86             87       Age      22.500      1       <NA>
## 70             88             89       Age      26.500      1       <NA>
## 71              0              0      <NA>       0.000     -1       Died
## 72              0              0      <NA>       0.000     -1   Survived
## 73              0              0      <NA>       0.000     -1   Survived
## 74              0              0      <NA>       0.000     -1   Survived
## 75             90             91       Age      40.500      1       <NA>
## 76              0              0      <NA>       0.000     -1   Survived
## 77             92             93       Age      47.500      1       <NA>
## 78             94             95       Age      30.500      1       <NA>
## 79             96             97       Age      35.500      1       <NA>
## 80              0              0      <NA>       0.000     -1       Died
## 81              0              0      <NA>       0.000     -1   Survived
## 82              0              0      <NA>       0.000     -1   Survived
## 83              0              0      <NA>       0.000     -1   Survived
## 84              0              0      <NA>       0.000     -1   Survived
## 85             98             99       Age      14.500      1       <NA>
## 86            100            101       Age      21.500      1       <NA>
## 87            102            103       Age      23.500      1       <NA>
## 88              0              0      <NA>       0.000     -1       Died
## 89              0              0      <NA>       0.000     -1   Survived
## 90              0              0      <NA>       0.000     -1   Survived
## 91              0              0      <NA>       0.000     -1   Survived
## 92              0              0      <NA>       0.000     -1       Died
## 93              0              0      <NA>       0.000     -1   Survived
## 94            104            105       Age      24.500      1       <NA>
## 95            106            107       Age      31.500      1       <NA>
## 96              0              0      <NA>       0.000     -1       Died
## 97            108            109       Age      41.500      1       <NA>
## 98              0              0      <NA>       0.000     -1   Survived
## 99            110            111       Age      16.500      1       <NA>
## 100             0              0      <NA>       0.000     -1   Survived
## 101             0              0      <NA>       0.000     -1   Survived
## 102             0              0      <NA>       0.000     -1   Survived
## 103             0              0      <NA>       0.000     -1   Survived
## 104           112            113       Age      21.500      1       <NA>
## 105           114            115       Age      27.500      1       <NA>
## 106             0              0      <NA>       0.000     -1       Died
## 107             0              0      <NA>       0.000     -1       Died
## 108           116            117       Age      40.250      1       <NA>
## 109           118            119       Age      42.500      1       <NA>
## 110           120            121       Age      15.500      1       <NA>
## 111           122            123       Age      17.500      1       <NA>
## 112           124            125       Age       9.500      1       <NA>
## 113           126            127       Age      22.500      1       <NA>
## 114           128            129       Age      26.500      1       <NA>
## 115           130            131       Age      28.750      1       <NA>
## 116           132            133       Age      38.500      1       <NA>
## 117             0              0      <NA>       0.000     -1       Died
## 118             0              0      <NA>       0.000     -1       Died
## 119             0              0      <NA>       0.000     -1       Died
## 120             0              0      <NA>       0.000     -1   Survived
## 121             0              0      <NA>       0.000     -1   Survived
## 122             0              0      <NA>       0.000     -1   Survived
## 123             0              0      <NA>       0.000     -1   Survived
## 124           134            135       Age       7.500      1       <NA>
## 125           136            137       Age      15.500      1       <NA>
## 126             0              0      <NA>       0.000     -1       Died
## 127           138            139       Age      23.250      1       <NA>
## 128           140            141       Age      25.500      1       <NA>
## 129             0              0      <NA>       0.000     -1       Died
## 130           142            143       Age      28.250      1       <NA>
## 131           144            145       Age      29.500      1       <NA>
## 132           146            147       Age      36.250      1       <NA>
## 133           148            149       Age      39.500      1       <NA>
## 134             0              0      <NA>       0.000     -1       Died
## 135           150            151       Age       8.500      1       <NA>
## 136             0              0      <NA>       0.000     -1       Died
## 137           152            153       Age      20.250      1       <NA>
## 138             0              0      <NA>       0.000     -1       Died
## 139           154            155       Age      23.750      1       <NA>
## 140             0              0      <NA>       0.000     -1       Died
## 141             0              0      <NA>       0.000     -1       Died
## 142             0              0      <NA>       0.000     -1       Died
## 143             0              0      <NA>       0.000     -1       Died
## 144             0              0      <NA>       0.000     -1       Died
## 145             0              0      <NA>       0.000     -1       Died
## 146             0              0      <NA>       0.000     -1       Died
## 147           156            157       Age      37.500      1       <NA>
## 148             0              0      <NA>       0.000     -1       Died
## 149             0              0      <NA>       0.000     -1       Died
## 150             0              0      <NA>       0.000     -1       Died
## 151             0              0      <NA>       0.000     -1   Survived
## 152           158            159       Age      18.500      1       <NA>
## 153           160            161       Age      20.750      1       <NA>
## 154             0              0      <NA>       0.000     -1       Died
## 155             0              0      <NA>       0.000     -1       Died
## 156             0              0      <NA>       0.000     -1       Died
## 157             0              0      <NA>       0.000     -1       Died
## 158           162            163       Age      17.500      1       <NA>
## 159           164            165       Age      19.500      1       <NA>
## 160             0              0      <NA>       0.000     -1       Died
## 161             0              0      <NA>       0.000     -1       Died
## 162           166            167       Age      16.500      1       <NA>
## 163             0              0      <NA>       0.000     -1       Died
## 164             0              0      <NA>       0.000     -1       Died
## 165             0              0      <NA>       0.000     -1       Died
## 166             0              0      <NA>       0.000     -1       Died
## 167             0              0      <NA>       0.000     -1       Died

Variable importance

Exercise: depression and voting

Resampling methods

  • Evaluating model fit/predictive power
  • How to avoid overfitting the data

Validation set

  • Randomly split data into two distinct sets
    • Training set
    • Test set
  • Train model on training set
  • Evaluate fit on test set

Regression

Mean squared error

\[MSE = \frac{1}{n} \sum_{i = 1}^{n}{(y_i - \hat{f}(x_i))^2}\]

  • \(y_i =\) the observed response value for the \(i\)th observation
  • \(\hat{f}(x_i) =\) the predicted response value for the \(i\)th observation given by \(\hat{f}\)
  • \(n =\) the total number of observations

Split data

# fix the RNG seed so the random train/test split is reproducible
set.seed(1234)

# randomly assign half of Auto to the training set, half to the test set
auto_split <- initial_split(data = Auto, prop = 0.5)
auto_train <- training(auto_split)
auto_test <- testing(auto_split)

Train model

# linear regression of fuel efficiency on horsepower, fit on the
# training half only (glm with default gaussian family = lm)
auto_lm <- glm(mpg ~ horsepower, data = auto_train)
summary(auto_lm)
## 
## Call:
## glm(formula = mpg ~ horsepower, data = auto_train)
## 
## Deviance Residuals: 
##      Min        1Q    Median        3Q       Max  
## -13.7105   -3.4442   -0.5342    2.6256   15.1015  
## 
## Coefficients:
##              Estimate Std. Error t value Pr(>|t|)    
## (Intercept) 40.057910   1.054798   37.98   <2e-16 ***
## horsepower  -0.157604   0.009402  -16.76   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for gaussian family taken to be 24.80151)
## 
##     Null deviance: 11780.6  on 195  degrees of freedom
## Residual deviance:  4811.5  on 194  degrees of freedom
## AIC: 1189.6
## 
## Number of Fisher Scoring iterations: 2
# training MSE: mean squared residual on the same data the model was
# fit to (%$% exposes columns to mean(); outer parens print the result)
(train_mse <- augment(auto_lm, newdata = auto_train) %>%
  mutate(.resid = mpg - .fitted,
         .resid2 = .resid ^ 2) %$%
  mean(.resid2))
## [1] 24.54843

Test model

# test MSE: the same calculation applied to the held-out test set,
# estimating out-of-sample predictive error
(test_mse <- augment(auto_lm, newdata = auto_test) %>%
  mutate(.resid = mpg - .fitted,
         .resid2 = .resid ^ 2) %$%
  mean(.resid2))
## [1] 23.38243

Compare models

Classification

# logistic regression with an Age x Sex interaction; glm() drops rows
# with missing values (the summary output reports 177 deleted)
survive_age_woman_x <- glm(Survived ~ Age * Sex, data = titanic,
                           family = binomial)
summary(survive_age_woman_x)
## 
## Call:
## glm(formula = Survived ~ Age * Sex, family = binomial, data = titanic)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -1.9401  -0.7136  -0.5883   0.7626   2.2455  
## 
## Coefficients:
##             Estimate Std. Error z value Pr(>|z|)   
## (Intercept)  0.59380    0.31032   1.913  0.05569 . 
## Age          0.01970    0.01057   1.863  0.06240 . 
## Sexmale     -1.31775    0.40842  -3.226  0.00125 **
## Age:Sexmale -0.04112    0.01355  -3.034  0.00241 **
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 964.52  on 713  degrees of freedom
## Residual deviance: 740.40  on 710  degrees of freedom
##   (177 observations deleted due to missingness)
## AIC: 748.4
## 
## Number of Fisher Scoring iterations: 4

Test error rate

# split the data into training and validation sets
# NOTE(review): no set.seed() before this split, so the partition (and
# the error rate below) will vary from run to run — confirm intended
titanic_split <- initial_split(data = titanic, prop = 0.5)

# fit model to training data
train_model <- glm(Survived ~ Age * Sex, data = training(titanic_split),
                   family = binomial)
summary(train_model)
## 
## Call:
## glm(formula = Survived ~ Age * Sex, family = binomial, data = training(titanic_split))
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -1.9374  -0.7041  -0.5866   0.7644   2.1918  
## 
## Coefficients:
##             Estimate Std. Error z value Pr(>|z|)  
## (Intercept)  0.58906    0.41752   1.411   0.1583  
## Age          0.01968    0.01414   1.391   0.1642  
## Sexmale     -1.42528    0.55970  -2.546   0.0109 *
## Age:Sexmale -0.03806    0.01829  -2.080   0.0375 *
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 485.10  on 358  degrees of freedom
## Residual deviance: 370.14  on 355  degrees of freedom
##   (87 observations deleted due to missingness)
## AIC: 378.14
## 
## Number of Fisher Scoring iterations: 4
# calculate predictions using validation set
# logit2prob converts the model's log-odds (.fitted) to probabilities;
# observations with p > .5 are classified as survived (pred = 1)
x_test_accuracy <- augment(train_model, newdata = testing(titanic_split)) %>% 
  as_tibble() %>%
  mutate(pred = logit2prob(.fitted),
         pred = as.numeric(pred > .5))

# calculate test error rate
# proportion of validation observations misclassified (NAs excluded)
mean(x_test_accuracy$Survived != x_test_accuracy$pred, na.rm = TRUE)
## [1] 0.2225352

Drawbacks to validation sets

Leave-one-out cross-validation

\[CV_{(n)} = \frac{1}{n} \sum_{i = 1}^{n}{MSE_i}\]

  • Extension of validation set to repeatedly split data and average results
  • Minimizes bias of estimated error rate
  • Low variance
  • Highly computationally intensive

rsample::loo_cv()

loocv_data <- loo_cv(Auto)
loocv_data
## # Leave-one-out cross-validation 
## # A tibble: 392 x 2
##    splits       id        
##    <list>       <chr>     
##  1 <S3: rsplit> Resample1 
##  2 <S3: rsplit> Resample2 
##  3 <S3: rsplit> Resample3 
##  4 <S3: rsplit> Resample4 
##  5 <S3: rsplit> Resample5 
##  6 <S3: rsplit> Resample6 
##  7 <S3: rsplit> Resample7 
##  8 <S3: rsplit> Resample8 
##  9 <S3: rsplit> Resample9 
## 10 <S3: rsplit> Resample10
## # ... with 382 more rows

Splits

first_resample <- loocv_data$splits[[1]]
first_resample
## <391/1/392>
training(first_resample)
## # A tibble: 391 x 9
##      mpg cylinders displacement horsepower weight acceleration  year origin
##    <dbl>     <dbl>        <dbl>      <dbl>  <dbl>        <dbl> <dbl>  <dbl>
##  1    18         8          307        130   3504         12      70      1
##  2    15         8          350        165   3693         11.5    70      1
##  3    18         8          318        150   3436         11      70      1
##  4    16         8          304        150   3433         12      70      1
##  5    17         8          302        140   3449         10.5    70      1
##  6    15         8          429        198   4341         10      70      1
##  7    14         8          454        220   4354          9      70      1
##  8    14         8          440        215   4312          8.5    70      1
##  9    14         8          455        225   4425         10      70      1
## 10    15         8          390        190   3850          8.5    70      1
## # ... with 381 more rows, and 1 more variable: name <fct>
assessment(first_resample)
## # A tibble: 1 x 9
##     mpg cylinders displacement horsepower weight acceleration  year origin
##   <dbl>     <dbl>        <dbl>      <dbl>  <dbl>        <dbl> <dbl>  <dbl>
## 1    14         8          318        150   4457         13.5    74      1
## # ... with 1 more variable: name <fct>

Holdout results

  1. Obtain the analysis data set (i.e. the \(n-1\) training set)
  2. Fit a linear regression model
  3. Predict the test data (also known as the assessment data, the \(1\) test set) using the broom package
  4. Determine the MSE for each sample

Holdout results

holdout_results <- function(splits) {
  # Train a linear model on the analysis portion (the n - 1 rows)
  fitted_model <- glm(mpg ~ horsepower, data = analysis(splits))
  
  # Pull out the single held-out observation for this split
  held_out <- assessment(splits)
  
  # Attach the model's prediction to the held-out row and compute the
  # residual so the MSE can be calculated downstream; return the
  # augmented assessment data
  augment(fitted_model, newdata = held_out) %>%
    mutate(.resid = mpg - .fitted)
}

Holdout results

holdout_results(loocv_data$splits[[1]])
## # A tibble: 1 x 12
##     mpg cylinders displacement horsepower weight acceleration  year origin
##   <dbl>     <dbl>        <dbl>      <dbl>  <dbl>        <dbl> <dbl>  <dbl>
## 1    14         8          318        150   4457         13.5    74      1
## # ... with 4 more variables: name <fct>, .fitted <dbl>, .se.fit <dbl>,
## #   .resid <dbl>
loocv_data$results <- map(loocv_data$splits, holdout_results)
loocv_data$mse <- map_dbl(loocv_data$results, ~ mean(.$.resid ^ 2))
loocv_data
## # Leave-one-out cross-validation 
## # A tibble: 392 x 4
##    splits       id         results               mse
##    <list>       <chr>      <list>              <dbl>
##  1 <S3: rsplit> Resample1  <tibble [1 × 12]>  5.17  
##  2 <S3: rsplit> Resample2  <tibble [1 × 12]>  1.77  
##  3 <S3: rsplit> Resample3  <tibble [1 × 12]>  2.07  
##  4 <S3: rsplit> Resample4  <tibble [1 × 12]>  2.40  
##  5 <S3: rsplit> Resample5  <tibble [1 × 12]> 14.8   
##  6 <S3: rsplit> Resample6  <tibble [1 × 12]>  2.77  
##  7 <S3: rsplit> Resample7  <tibble [1 × 12]> 56.9   
##  8 <S3: rsplit> Resample8  <tibble [1 × 12]> 22.6   
##  9 <S3: rsplit> Resample9  <tibble [1 × 12]>  0.0680
## 10 <S3: rsplit> Resample10 <tibble [1 × 12]> 50.1   
## # ... with 382 more rows
loocv_data %>%
  summarize(mse = mean(mse))
## # Leave-one-out cross-validation 
## # A tibble: 1 x 1
##     mse
##   <dbl>
## 1  24.2

Compare polynomial terms

## # A tibble: 5 x 2
##   terms mse_loocv
##   <int>     <dbl>
## 1     1      24.2
## 2     2      19.2
## 3     3      19.3
## 4     4      19.4
## 5     5      19.0

LOOCV in classification

# function to generate assessment statistics for titanic model
holdout_results <- function(splits) {
  # Fit the logistic regression to the analysis (n - 1) set
  mod <- glm(Survived ~ Age * Sex, data = analysis(splits),
             family = binomial)
  
  # Save the heldout observation
  holdout <- assessment(splits)
  
  # `augment` will save the predictions with the holdout data set.
  # Reuse the saved `holdout` instead of calling assessment(splits)
  # a second time (the original recomputed it, leaving `holdout` unused).
  # logit2prob converts log-odds to probabilities; p > .5 => pred = 1
  res <- augment(mod, newdata = holdout) %>% 
    as_tibble() %>%
    mutate(pred = logit2prob(.fitted),
           pred = as.numeric(pred > .5))

  # Return the assessment data set with the additional columns
  res
}

# run holdout_results() on every leave-one-out split, then compute the
# misclassification rate within each one-row assessment set
titanic_loocv <- loo_cv(titanic) %>%
  mutate(results = map(splits, holdout_results),
         error_rate = map_dbl(results, ~ mean(.$Survived != .$pred, na.rm = TRUE)))
# LOOCV estimate of the test error rate (averaged over all splits)
mean(titanic_loocv$error_rate, na.rm = TRUE)
## [1] 0.219888

Exercise: LOOCV in linear regression

\(k\)-fold cross-validation

\[CV_{(k)} = \frac{1}{k} \sum_{i = 1}^{k}{MSE_i}\]

  • Split data into \(k\) folds
  • Repeat training/test process for each fold
  • LOOCV: \(k=n\)

k-fold CV in linear regression

# modified function to estimate model with varying highest order polynomial
holdout_results <- function(splits, i) {
  # Train a degree-i polynomial regression on the analysis set
  poly_fit <- glm(mpg ~ poly(horsepower, i), data = analysis(splits))
  
  # Pull out the held-out observations for this fold
  held_out <- assessment(splits)
  
  # Attach predictions and residuals so the fold MSE can be computed
  # later; return the augmented assessment data
  augment(poly_fit, newdata = held_out) %>%
    mutate(.resid = mpg - .fitted)
}

# function to return MSE for a specific higher-order polynomial term
poly_mse <- function(i, vfold_data){
  # Fit the degree-i model on every fold and record each fold's MSE
  fold_results <- vfold_data %>%
    mutate(results = map(splits, holdout_results, i),
           mse = map_dbl(results, ~ mean(.$.resid ^ 2)))
  
  # Cross-validated MSE = average of the per-fold MSEs
  mean(fold_results$mse)
}

# split Auto into 10 folds
auto_cv10 <- vfold_cv(data = Auto, v = 10)

# compute the 10-fold CV MSE for polynomial degrees 1 through 5
# (data_frame() is deprecated in tibble; tibble() is the drop-in
# replacement and evaluates columns sequentially the same way)
cv_mse <- tibble(terms = seq(from = 1, to = 5),
                 mse_vfold = map_dbl(terms, poly_mse, auto_cv10))
cv_mse
## # A tibble: 5 x 2
##   terms mse_vfold
##   <int>     <dbl>
## 1     1      24.2
## 2     2      19.2
## 3     3      19.3
## 4     4      19.3
## 5     5      18.9

Computational speed of LOOCV

Computational speed of 10-fold CV

k-fold CV in logistic regression

# function to generate assessment statistics for titanic model
holdout_results <- function(splits) {
  # Fit the logistic regression to the training (analysis) set
  mod <- glm(Survived ~ Age * Sex, data = analysis(splits),
             family = binomial)
  
  # Save the heldout observations
  holdout <- assessment(splits)
  
  # `augment` will save the predictions with the holdout data set.
  # Reuse the saved `holdout` instead of calling assessment(splits)
  # a second time (the original recomputed it, leaving `holdout` unused).
  # logit2prob converts log-odds to probabilities; p > .5 => pred = 1
  res <- augment(mod, newdata = holdout) %>% 
    as_tibble() %>%
    mutate(pred = logit2prob(.fitted),
           pred = as.numeric(pred > .5))

  # Return the assessment data set with the additional columns
  res
}

# run holdout_results() on each of the 10 folds, then compute the
# misclassification rate within each fold's assessment set
titanic_cv10 <- vfold_cv(data = titanic, v = 10) %>%
  mutate(results = map(splits, holdout_results),
         error_rate = map_dbl(results, ~ mean(.$Survived != .$pred, na.rm = TRUE)))
# 10-fold CV estimate of the test error rate (averaged over folds)
mean(titanic_cv10$error_rate, na.rm = TRUE)
## [1] 0.2209604

Exercise: k-fold CV